# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import re
import tensorflow as tf
import numpy as np
import argparse
import facenet
from npu_bridge.npu_init import *
from skimage.transform import resize
import os
import sys
from tensorflow.python.ops import data_flow_ops
from sklearn import metrics
from scipy.optimize import brentq
from scipy import interpolate
from scipy.interpolate import interp1d
from scipy.io import loadmat
from scipy import spatial
import matplotlib.pyplot as plt
import random
import math
import scipy
import shutil

MAT_folder = '/home/ma-user/modelarts/inputs/data_url_0/gt_mat'
img_folder = '/home/ma-user/modelarts/inputs/data_url_0/MOT17Det/train'
temp_folder = '/home/ma-user/modelarts/outputs/train_url_0/temp'
triplet_model = '/home/ma-user/modelarts/inputs/data_url_0/model/20211206-213038/'
res_dir = '/home/ma-user/modelarts/outputs/train_url_0/bin/'

bbox_size = 182
max_length = 64
feature_size = 4 + 512
batch_size = 32
num_classes = 2
margin = 0.15
sample_prob = [600, 105, 837, 525, 654, 900, 750]
lr = 1e-3


def main():
    total_batch_x, total_batch_y = generate_data(feature_size, max_length, batch_size * 10, MAT_folder,
                                                 img_folder)
    total_batch_x = interp_batch(total_batch_x)

    # delete temp folder
    shutil.rmtree(temp_folder)

    remove_idx = []
    for k in range(len(total_batch_x)):
        if np.sum(total_batch_x[k, 0, :, 1]) == 0:
            remove_idx.append(k)
    total_batch_x = np.delete(total_batch_x, np.array(remove_idx), axis=0)
    total_batch_y = np.delete(total_batch_y, np.array(remove_idx), axis=0)
    print(len(total_batch_y))

    total_batch_x[:, 4:, :, 0] = 10 * total_batch_x[:, 4:, :, 0]
    temp_X = np.copy(total_batch_x)
    temp_Y = np.copy(total_batch_y)
    idx = np.arange(total_batch_x.shape[0])
    np.random.shuffle(idx)
    for k in range(len(idx)):
        total_batch_x[idx[k], :, :, :] = temp_X[k, :, :, :]
        total_batch_y[idx[k], :] = temp_Y[k, :]
    num_batch = int(np.ceil(len(total_batch_y) / batch_size))

    # shuffle 4 times
    acc = []
    for kk in range(num_batch):
        temp_batch_size = batch_size
        if kk == num_batch - 1:
            temp_batch_size = len(total_batch_y) - batch_size * (num_batch - 1)
        batch_x = total_batch_x[kk * batch_size:kk * batch_size + temp_batch_size, :, :, :]
        batch_y = total_batch_y[kk * batch_size:kk * batch_size + temp_batch_size, :]

        x = np.zeros((temp_batch_size, 1, max_length, 1))
        y = np.zeros((temp_batch_size, 1, max_length, 1))
        w = np.zeros((temp_batch_size, 1, max_length, 1))
        h = np.zeros((temp_batch_size, 1, max_length, 1))
        ap = np.zeros((temp_batch_size, feature_size - 4, max_length, 1))
        mask_1 = np.zeros((temp_batch_size, 1, max_length, 2))
        mask_2 = np.zeros((temp_batch_size, feature_size - 4, max_length, 2))

        x[:, 0, :, 0] = batch_x[:, 0, :, 0]
        y[:, 0, :, 0] = batch_x[:, 1, :, 0]
        w[:, 0, :, 0] = batch_x[:, 2, :, 0]
        h[:, 0, :, 0] = batch_x[:, 3, :, 0]

        ap[:, :, :, 0] = batch_x[:, 4:, :, 0]

        mask_1[:, 0, :, :] = batch_x[:, 0, :, 1:]
        mask_2[:, :, :, :] = batch_x[:, 4:, :, 1:]

        temp_path = res_dir + '{}/'.format(str(kk))
        if not os.path.exists(temp_path):
            os.makedirs(temp_path)
        x.tofile(temp_path + 'x.bin')
        y.tofile(temp_path + 'y.bin')
        w.tofile(temp_path + 'w.bin')
        h.tofile(temp_path + 'h.bin')
        ap.tofile(temp_path + 'ap.bin')
        mask_1.tofile(temp_path + 'mask_1.bin')
        mask_2.tofile(temp_path + 'mask_2.bin')
        batch_y.tofile(temp_path + 'batch_y.bin')
    print("+++++++++++++++++++++++++++++finished++++++++++++++++++++++++++++++")




# In[3]:
def draw_traj(x, mask_1):
    fig, ax = plt.subplots()
    ax.plot(x, color=[0.5, 0.5, 0.5], marker='o', linestyle='None')
    t1 = np.where(mask_1[:, 0] == 1)[0]
    t2 = np.where(mask_1[:, 1] == 1)[0]
    ax.plot(t1, x[mask_1[:, 0] == 1], color=[0.2, 0.6, 0.86], marker='o', linestyle='None')
    ax.plot(t2, x[mask_1[:, 1] == 1], color=[0.18, 0.8, 0.44], marker='o', linestyle='None')
    ax.axhline(y=0, color='k')
    ax.axvline(x=0, color='k')
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['bottom'].set_visible(False)
    ax.spines['left'].set_visible(False)
    plt.xlim(0, 64)
    y_range = np.max(x) - np.min(x)
    plt.ylim(np.min(x) - y_range / 50, np.max(x) + y_range / 20)
    plt.show()


def draw_fea_map(x):
    fig, ax = plt.subplots()
    ax.imshow(np.power(np.transpose(x), 0.2), cmap='gray')
    ax.set_aspect(0.1)
    plt.axis('off')
    # plt.xticks(range(8))
    plt.show()


def interp_batch(total_batch_x):
    interp_batch_x = total_batch_x.copy()
    N_batch = total_batch_x.shape[0]
    for n in range(N_batch):
        temp_idx = np.where(total_batch_x[n, 0, :, 1] == 1)[0]
        t1 = int(temp_idx[-1])
        temp_idx = np.where(total_batch_x[n, 0, :, 2] == 1)[0]
        t2 = int(temp_idx[0])
        if t2 - t1 <= 1:
            continue
        interp_t = np.array(range(t1 + 1, t2))
        for k in range(total_batch_x.shape[1]):
            # temp_std = np.std(total_batch_x[n,k,total_batch_x[n,k,:,0]!=0,0])
            temp_std1 = np.std(total_batch_x[n, k, total_batch_x[n, 0, :, 1] != 0, 0])
            temp_std2 = np.std(total_batch_x[n, k, total_batch_x[n, 0, :, 2] != 0, 0])
            x_p = [t1, t2]
            f_p = [total_batch_x[n, k, t1, 0], total_batch_x[n, k, t2, 0]]
            # *************************************
            # interp_batch_x[n,k,t1+1:t2,0] = np.interp(interp_t,x_p,f_p)+np.random.normal(0, temp_std, t2-t1-1)
            # *************************************
            interp_batch_x[n, k, t1 + 1:t2, 0] = np.interp(interp_t, x_p, f_p) + np.random.normal(0, (
                        temp_std1 + temp_std2) * 0.5, t2 - t1 - 1)
    return interp_batch_x


def num_str(num, length):
    cnt = 1
    temp = num
    while 1:
        temp = int(temp / 10)
        if temp > 0:
            cnt = cnt + 1
        else:
            break
    num_len = cnt
    for n in range(length - num_len):
        if n == 0:
            out_str = '0'
        else:
            out_str = out_str + '0'
    if length - num_len > 0:
        return out_str + str(num)
    else:
        return str(num)


def file_name(num, length):
    cnt = 1
    temp = num
    while 1:
        temp = int(temp / 10)
        if temp > 0:
            cnt = cnt + 1
        else:
            break
    num_len = cnt
    for n in range(length - num_len):
        if n == 0:
            out_str = '0'
        else:
            out_str = out_str + '0'
    if length - num_len > 0:
        return out_str + str(num)
    else:
        return str(num)


def evaluate(sess, iterator, image_paths_placeholder, labels_placeholder, phase_train_placeholder,
             batch_size_placeholder, control_placeholder,
             embeddings, labels, image_paths, batch_size, distance_metric):
    # Run forward pass to calculate embeddings
    # print('Runnning forward pass on LFW images')

    use_flipped_images = False
    use_fixed_image_standardization = False
    use_random_rotate = False
    use_radnom_crop = False
    # Enqueue one epoch of image paths and labels
    nrof_embeddings = len(image_paths)  # nrof_pairs * nrof_images_per_pair
    nrof_flips = 2 if use_flipped_images else 1
    nrof_images = nrof_embeddings * nrof_flips
    labels_array = np.expand_dims(np.arange(0, nrof_images), 1)
    image_paths_array = np.expand_dims(np.repeat(np.array(image_paths), nrof_flips), 1)
    control_array = np.zeros_like(labels_array, np.int32)

    if use_fixed_image_standardization:
        control_array += np.ones_like(labels_array) * facenet.FIXED_STANDARDIZATION
    if use_flipped_images:
        # Flip every second image
        control_array += (labels_array % 2) * facenet.FLIP
    if use_random_rotate:
        control_array += facenet.RANDOM_ROTATE
    if use_radnom_crop:
        control_array += facenet.RANDOM_CROP

    sess.run(iterator.initializer, {batch_size_placeholder: batch_size, image_paths_placeholder: image_paths_array,
                                    labels_placeholder: labels_array, control_placeholder: control_array})

    embedding_size = int(embeddings.get_shape()[1])
    assert nrof_images % batch_size == 0, 'The number of LFW images must be an integer multiple of the LFW batch size'
    nrof_batches = nrof_images // batch_size
    emb_array = np.zeros((nrof_images, embedding_size))
    lab_array = np.zeros((nrof_images,))
    for i in range(nrof_batches):
        emb, lab = sess.run([embeddings, labels], feed_dict={phase_train_placeholder: False})
        lab_array[lab] = lab
        emb_array[lab, :] = emb
        if i % 10 == 9:
            print('.', end='')
            sys.stdout.flush()
    # import pdb; pdb.set_trace()
    # np.savetxt("emb_array.csv", emb_array, delimiter=",")
    return emb_array


def split_track(X_2d, Y_2d, W_2d, H_2d, V_2d, img_size, obj_id, noise_scale, connect_thresh):
    err_flag = 0
    # 取出指定轨迹的所有w的值
    part_W_mat = W_2d[:, obj_id]
    # 找到所有w不为0的轨迹，即该物体出现的帧号
    non_zero_idx = np.where(part_W_mat > 0)[0]
    # 确保语义正确
    if len(non_zero_idx) <= 1 or np.max(non_zero_idx) - np.min(non_zero_idx) + 1 != len(non_zero_idx):
        err_flag = 1
        return [], [], err_flag

    # 该轨迹出现的frame_id和结束的frame_id
    st_fr = np.min(non_zero_idx)
    end_fr = np.max(non_zero_idx)

    bbox_tracklet = []

    #
    v_flag = 1
    rand_num = random.uniform(0.0, 1.0)
    if rand_num < 0.5 or len(V_2d) == 0:  # 0.5
        v_flag = 0

    # v_flag = 0
    for k in range(st_fr, end_fr + 1):
        # 生成四个小于noise_scale的随机数，
        rand_num = np.zeros((1, 4))
        for kk in range(4):
            while 1:
                rand_num[0, kk] = np.random.normal(0, 0.05, size=1)[0]
                if abs(rand_num[0, kk]) < noise_scale:
                    break

        # k frame的id
        # 不要超过图片边界
        x0 = max(0, X_2d[k, obj_id])
        x1 = min(img_size[0] - 1, X_2d[k, obj_id] + W_2d[k, obj_id])

        y0 = max(0, Y_2d[k, obj_id])
        y1 = min(img_size[1] - 1, Y_2d[k, obj_id] + H_2d[k, obj_id])

        w0 = max(1, x1 - x0)
        h0 = max(1, y1 - y0)

        xmin = max(0, x0 + rand_num[0, 0] * w0)
        xmin = min(img_size[0] - 2, xmin)

        ymin = max(0, y0 + rand_num[0, 1] * h0)
        ymin = min(img_size[1] - 2, ymin)

        ww = max(1, w0 + rand_num[0, 2] * w0)
        hh = max(1, h0 + rand_num[0, 3] * h0)

        xmax = min(xmin + ww, img_size[0] - 1)
        ymax = min(ymin + hh, img_size[1] - 1)

        ww = max(1, xmax - xmin + 1)
        hh = max(1, ymax - ymin + 1)
        # print(rand_num)
        # 随机从该帧的该轨迹的图片中裁剪一部分
        temp_bbox = [int(k), int(xmin), int(ymin), int(xmax), int(ymax), float(V_2d[k, obj_id])]

        if k == st_fr:
            bbox = []

        # 如果一个bbox，除去重合的剩余部分的比值，小于0.5
        # 说明此刻该物体的轨迹不值得被研究（样本像素点太少），
        # 即认为它的轨迹到此中断了，只将之前的一段添加到它的轨迹框图列表中。
        if temp_bbox[-1] < 0.5:
            if len(bbox) == 0:
                continue
            else:
                bbox_tracklet.append(np.array(bbox))
                bbox = []
                continue

        rand_num = random.uniform(0.0, 1.0)
        if v_flag == 0:
            temp_connect = connect_thresh
        else:
            temp_connect = connect_thresh * math.sqrt(V_2d[k, obj_id])

        # 如果达到链接片段的标准，则将当前帧中的bbox添加到当前的轨迹段中
        if rand_num < temp_connect:
            bbox.append(temp_bbox)
            # 如果已经分析完整个视频帧，则将当前轨迹段添加到轨迹段集合中
            if k == end_fr and len(bbox) != 0:
                bbox_tracklet.append(np.array(bbox))
                # bbox_num.append(len(bbox))
        else:
            # 当前轨迹段，到此结束，将当前轨迹段入档，并重新维护一个新的轨迹段，并将当前bbox添加到新的轨迹段上
            if len(bbox) != 0:
                bbox_tracklet.append(np.array(bbox))
                # bbox_num.append(len(bbox))
            bbox = []
            bbox.append(temp_bbox)

        '''    
        if k==st_fr:
            bbox = []
            bbox.append(temp_bbox)
        else:
            rand_num = random.uniform(0.0,1.0)
            if v_flag==0:
                temp_connect = connect_thresh
            else:
                temp_connect = connect_thresh*math.sqrt(V_2d[k,obj_id])

            if rand_num<temp_connect:
                bbox.append(temp_bbox)
                if k==end_fr:
                    bbox_tracklet.append(np.array(bbox))
                    bbox_num.append(len(bbox))
            else:
                bbox_tracklet.append(np.array(bbox))
                bbox_num.append(len(bbox))
                bbox = []
                bbox.append(temp_bbox)
         '''
    # import pdb; pdb.set_trace()
    # 计算该物品总共的有几段连续的轨迹
    t_interval = np.zeros((len(bbox_tracklet), 2))
    # 对于每一段轨迹
    for k in range(len(bbox_tracklet)):
        # print(bbox_tracklet[k])
        # bbox_tracklet bbox temp_bbox
        # 记录起始帧数
        t_interval[k, 0] = bbox_tracklet[k][0, 0]
        # 记录结束帧数
        t_interval[k, 1] = bbox_tracklet[k][-1, 0]
    return bbox_tracklet, t_interval, err_flag


def generate_data(feature_size, max_length, batch_size, MAT_folder, img_folder):
    noise_scale = 0.15  # 0.15
    # connect_thresh = 0.95
    connect_thresh = np.random.uniform(0.6, 1)
    # sample_p = np.array([0.0852,0.1996,0.2550,0.0313,0.0854,0.1546,0.1890])
    # 从不同视频文件中的采样比例
    sample_p = np.array(sample_prob)
    sample_p = list(sample_p / sum(sample_p))

    # load mat files
    Mat_paths = os.listdir(MAT_folder)
    choose_idx = np.random.choice(len(Mat_paths), size=batch_size, p=sample_p)
    Mat_files = []
    for n in range(batch_size):
        temp_path = MAT_folder + '/' + Mat_paths[choose_idx[n]]
        temp_mat_file = loadmat(temp_path)
        Mat_files.append(temp_mat_file)

    # X = np.zeros((batch_size,1,max_length,1+4+512))
    X = np.zeros((batch_size, feature_size, max_length, 3))
    Y = np.zeros((batch_size, 2))
    crop_bbox = []

    # positive
    # 选出batch_size/2个正样本对
    for n in range(int(batch_size / 2)):
        # print(n)
        # frame总数
        fr_num = Mat_files[n]['gtInfo'][0][0][0].shape[0]
        # tracklet总数
        id_num = Mat_files[n]['gtInfo'][0][0][0].shape[1]

        # 将标签置为1 表示正样本对？？？？
        Y[n, 0] = 1

        X_2d = Mat_files[n]['gtInfo'][0][0][0]
        Y_2d = Mat_files[n]['gtInfo'][0][0][1]
        W_2d = Mat_files[n]['gtInfo'][0][0][2]
        H_2d = Mat_files[n]['gtInfo'][0][0][3]

        #########################################
        # 为啥要这么干？？？？？
        X_2d = X_2d - margin * W_2d
        Y_2d = Y_2d - margin * H_2d
        W_2d = (1 + 2 * margin) * W_2d
        H_2d = (1 + 2 * margin) * H_2d
        ##########################################

        if len(Mat_files[n]['gtInfo'][0][0]) <= 4:
            V_2d = []
        else:
            V_2d = Mat_files[n]['gtInfo'][0][0][4]
        if len(Mat_files[n]['gtInfo'][0][0]) == 6:
            img_size = Mat_files[n]['gtInfo'][0][0][5][0]
            # img_size = [800,600]，此处是将其解析出来成为数字元组
            temp_size = re.findall(r"\d+\.?\d*", img_size)
            img_size = [int(temp_size[0]), int(temp_size[1])]
        else:
            img_size = [1920, 1080]
        # V_2d = []

        # import pdb; pdb.set_trace()

        while 1:
            # 随机生成一个大于等于0，小于id_num的单个数字
            obj_id = np.random.randint(id_num, size=1)[0]

            # bbox_tracklet：被拆分的轨迹段的数组， t_interval 每一段轨迹段的起始帧数与结束帧数， err_flag是否出现错误
            bbox_tracklet, t_interval, err_flag = split_track(X_2d, Y_2d, W_2d, H_2d, V_2d, img_size, obj_id,
                                                              noise_scale, connect_thresh)
            if err_flag == 1:
                continue

            cand_pairs = []
            if len(bbox_tracklet) <= 1:
                continue
            for k1 in range(len(bbox_tracklet) - 1):
                for k2 in range(k1 + 1, len(bbox_tracklet)):
                    if t_interval[k1, 0] + max_length > t_interval[k2, 0]:
                        t_dist = t_interval[k2, 0] - t_interval[k1, 1]
                        cand_pairs.append([k1, k2, t_dist])
            if len(cand_pairs) == 0:
                continue

            cand_pairs = np.array(cand_pairs)
            rand_num = np.random.rand(1)[0]
            # print(rand_num)
            if rand_num < 0.7:
                select_p = np.exp(-np.power(cand_pairs[:, 2], 2) / 100)
                select_p = select_p / sum(select_p)
                # print(select_p)
                pair_idx = np.random.choice(len(cand_pairs), size=1, p=select_p)[0]
            else:
                pair_idx = np.random.randint(len(cand_pairs), size=1)[0]
            select_pair = cand_pairs[pair_idx]
            select_pair = select_pair.astype(int)

            abs_fr_t1 = int(t_interval[select_pair[0], 0])
            abs_fr_t2 = int(t_interval[select_pair[0], 1])
            abs_fr_t3 = int(t_interval[select_pair[1], 0])
            abs_fr_t4 = int(min(abs_fr_t1 + max_length - 1, t_interval[select_pair[1], 1]))

            t1 = 0
            t2 = abs_fr_t2 - abs_fr_t1
            t3 = abs_fr_t3 - abs_fr_t1
            t4 = abs_fr_t4 - abs_fr_t1

            # mask
            X[n, :, t1:t2 + 1, 1] = 1
            X[n, :, t3:t4 + 1, 2] = 1

            # X
            X[n, 0, t1:t2 + 1, 0] = 0.5 * (bbox_tracklet[select_pair[0]][:, 1] + bbox_tracklet[select_pair[0]][:, 3]) / \
                                    img_size[0]
            X[n, 0, t3:t4 + 1, 0] = 0.5 * (
                        bbox_tracklet[select_pair[1]][0:t4 - t3 + 1, 1] + bbox_tracklet[select_pair[1]][0:t4 - t3 + 1,
                                                                          3]) / img_size[0]

            # Y
            X[n, 1, t1:t2 + 1, 0] = 0.5 * (bbox_tracklet[select_pair[0]][:, 2] + bbox_tracklet[select_pair[0]][:, 4]) / \
                                    img_size[1]
            X[n, 1, t3:t4 + 1, 0] = 0.5 * (
                        bbox_tracklet[select_pair[1]][0:t4 - t3 + 1, 2] + bbox_tracklet[select_pair[1]][0:t4 - t3 + 1,
                                                                          4]) / img_size[1]

            # W
            X[n, 2, t1:t2 + 1, 0] = (bbox_tracklet[select_pair[0]][:, 3] - bbox_tracklet[select_pair[0]][:, 1]) / \
                                    img_size[0]
            X[n, 2, t3:t4 + 1, 0] = (bbox_tracklet[select_pair[1]][0:t4 - t3 + 1, 3] - bbox_tracklet[select_pair[1]][
                                                                                       0:t4 - t3 + 1, 1]) / img_size[0]

            # H
            X[n, 3, t1:t2 + 1, 0] = (bbox_tracklet[select_pair[0]][:, 4] - bbox_tracklet[select_pair[0]][:, 2]) / \
                                    img_size[1]
            X[n, 3, t3:t4 + 1, 0] = (bbox_tracklet[select_pair[1]][0:t4 - t3 + 1, 4] - bbox_tracklet[select_pair[1]][
                                                                                       0:t4 - t3 + 1, 2]) / img_size[1]
            '''       
            plt.plot(X[n,0,:,0], 'ro')
            plt.show()
            plt.plot(X[n,1,:,0], 'ro')
            plt.show()
            plt.plot(X[n,2,:,0], 'ro')
            plt.show()
            plt.plot(X[n,3,:,0], 'ro')
            plt.show()
            plt.plot(X[n,0,:,1], 'ro')
            plt.show()
            plt.plot(X[n,0,:,2], 'ro')
            plt.show()    
            import pdb; pdb.set_trace() 
            '''

            # save all bbox
            temp_crop_bbox = np.concatenate(
                (bbox_tracklet[select_pair[0]], bbox_tracklet[select_pair[1]][0:t4 - t3 + 1, :]), axis=0)
            temp_crop_bbox = temp_crop_bbox.astype(int)
            crop_bbox.append(temp_crop_bbox)
            break

    # import pdb; pdb.set_trace()

    # negative
    for n in range(int(batch_size / 2), batch_size):
        fr_num = Mat_files[n]['gtInfo'][0][0][0].shape[0]
        id_num = Mat_files[n]['gtInfo'][0][0][0].shape[1]
        Y[n, 1] = 1

        X_2d = Mat_files[n]['gtInfo'][0][0][0]
        Y_2d = Mat_files[n]['gtInfo'][0][0][1]
        W_2d = Mat_files[n]['gtInfo'][0][0][2]
        H_2d = Mat_files[n]['gtInfo'][0][0][3]

        #########################################
        X_2d = X_2d - margin * W_2d
        Y_2d = Y_2d - margin * H_2d
        W_2d = (1 + 2 * margin) * W_2d
        H_2d = (1 + 2 * margin) * H_2d
        ##########################################

        if len(Mat_files[n]['gtInfo'][0][0]) <= 4:
            V_2d = []
        else:
            V_2d = Mat_files[n]['gtInfo'][0][0][4]
        if len(Mat_files[n]['gtInfo'][0][0]) == 6:
            img_size = Mat_files[n]['gtInfo'][0][0][5][0]
            temp_size = re.findall(r"\d+\.?\d*", img_size)
            img_size = [int(temp_size[0]), int(temp_size[1])]
        else:
            img_size = [1920, 1080]
        # V_2d = []
        # img_size = [1920,1080]

        # check candidate obj pairs
        # pair_mat = np.zeros((id_num,id_num))
        cand_idx_pairs = []
        for n1 in range(id_num - 1):
            for n2 in range(n1 + 1, id_num):
                cand_fr1 = np.where(W_2d[:, n1] > 0)[0]
                cand_fr2 = np.where(W_2d[:, n2] > 0)[0]
                if max(cand_fr1[0], cand_fr2[0]) < min(cand_fr1[-1], cand_fr2[-1]):
                    cand_idx_pairs.append([n1, n2])
                    # pair_mat[n1,n2] = 1

        # cand_pairs = np.nonzero(pair_mat)
        while 1:
            #
            if len(cand_idx_pairs) == 0:
                import pdb;
                pdb.set_trace()
            pair_idx = np.random.randint(len(cand_idx_pairs), size=1)[0]
            obj_id1 = cand_idx_pairs[pair_idx][0]
            obj_id2 = cand_idx_pairs[pair_idx][1]
            # import pdb; pdb.set_trace()

            part_W_mat1 = W_2d[:, obj_id1]
            non_zero_idx1 = np.where(part_W_mat1 > 0)[0]
            part_W_mat2 = W_2d[:, obj_id2]
            non_zero_idx2 = np.where(part_W_mat2 > 0)[0]
            if len(non_zero_idx1) == 0 or len(non_zero_idx2) == 0 or \
                    max(non_zero_idx1) + max_length < min(non_zero_idx2) or min(non_zero_idx1) > max(non_zero_idx2):
                continue

            bbox_tracklet1, t_interval1, err_flag = split_track(X_2d, Y_2d, W_2d, H_2d, V_2d, img_size, obj_id1,
                                                                noise_scale, connect_thresh)
            if err_flag == 1:
                continue
            bbox_tracklet2, t_interval2, err_flag = split_track(X_2d, Y_2d, W_2d, H_2d, V_2d, img_size, obj_id2,
                                                                noise_scale, connect_thresh)
            if err_flag == 1:
                continue

            cand_pairs = []
            if len(bbox_tracklet1) <= 1 or len(bbox_tracklet2) <= 1:
                continue
            for k1 in range(len(bbox_tracklet1)):
                for k2 in range(len(bbox_tracklet2)):
                    if t_interval1[k1, 0] + max_length > t_interval2[k2, 0] and t_interval1[k1, 1] < t_interval2[k2, 0]:
                        t_dist = t_interval2[k2, 0] - t_interval1[k1, 1]
                        cand_pairs.append([k1, k2, t_dist])
            if len(cand_pairs) == 0:
                continue

            cand_pairs = np.array(cand_pairs)
            rand_num = np.random.rand(1)[0]
            # print(rand_num)
            if rand_num < 0.7:
                select_p = np.exp(-np.power(cand_pairs[:, 2], 2) / 100)
                select_p = select_p / sum(select_p)
                # print(select_p)
                pair_idx = np.random.choice(len(cand_pairs), size=1, p=select_p)[0]
            else:
                pair_idx = np.random.randint(len(cand_pairs), size=1)[0]
            select_pair = cand_pairs[pair_idx]
            select_pair = select_pair.astype(int)

            abs_fr_t1 = int(t_interval1[select_pair[0], 0])
            abs_fr_t2 = int(t_interval1[select_pair[0], 1])
            abs_fr_t3 = int(t_interval2[select_pair[1], 0])
            abs_fr_t4 = int(min(abs_fr_t1 + max_length - 1, t_interval2[select_pair[1], 1]))

            t1 = 0
            t2 = abs_fr_t2 - abs_fr_t1
            t3 = abs_fr_t3 - abs_fr_t1
            t4 = abs_fr_t4 - abs_fr_t1

            # mask
            X[n, :, t1:t2 + 1, 1] = 1
            X[n, :, t3:t4 + 1, 2] = 1
            # X[n,4:,t1:t2+1,1] = bbox_tracklet1[select_pair[0]][:,5]
            # X[n,4:,t3:t4+1,2] = bbox_tracklet2[select_pair[1]][0:t4-t3+1,5]

            # X
            X[n, 0, t1:t2 + 1, 0] = 0.5 * (
                        bbox_tracklet1[select_pair[0]][:, 1] + bbox_tracklet1[select_pair[0]][:, 3]) / img_size[0]
            X[n, 0, t3:t4 + 1, 0] = 0.5 * (
                        bbox_tracklet2[select_pair[1]][0:t4 - t3 + 1, 1] + bbox_tracklet2[select_pair[1]][0:t4 - t3 + 1,
                                                                           3]) / img_size[0]

            # Y
            X[n, 1, t1:t2 + 1, 0] = 0.5 * (
                        bbox_tracklet1[select_pair[0]][:, 2] + bbox_tracklet1[select_pair[0]][:, 4]) / img_size[1]
            X[n, 1, t3:t4 + 1, 0] = 0.5 * (
                        bbox_tracklet2[select_pair[1]][0:t4 - t3 + 1, 2] + bbox_tracklet2[select_pair[1]][0:t4 - t3 + 1,
                                                                           4]) / img_size[1]

            # W
            X[n, 2, t1:t2 + 1, 0] = (bbox_tracklet1[select_pair[0]][:, 3] - bbox_tracklet1[select_pair[0]][:, 1]) / \
                                    img_size[0]
            X[n, 2, t3:t4 + 1, 0] = (bbox_tracklet2[select_pair[1]][0:t4 - t3 + 1, 3] - bbox_tracklet2[select_pair[1]][
                                                                                        0:t4 - t3 + 1, 1]) / img_size[0]

            # H
            X[n, 3, t1:t2 + 1, 0] = (bbox_tracklet1[select_pair[0]][:, 4] - bbox_tracklet1[select_pair[0]][:, 2]) / \
                                    img_size[1]
            X[n, 3, t3:t4 + 1, 0] = (bbox_tracklet2[select_pair[1]][0:t4 - t3 + 1, 4] - bbox_tracklet2[select_pair[1]][
                                                                                        0:t4 - t3 + 1, 2]) / img_size[1]

            # save all bbox
            temp_crop_bbox = np.concatenate(
                (bbox_tracklet1[select_pair[0]], bbox_tracklet2[select_pair[1]][0:t4 - t3 + 1, :]), axis=0)
            temp_crop_bbox = temp_crop_bbox.astype(int)
            crop_bbox.append(temp_crop_bbox)
            break

    # crop data to a temp folder
    if not os.path.exists(temp_folder):
        os.makedirs(temp_folder)
    all_paths = []
    for n in range(batch_size):
        temp_all_path = []
        seq_name = Mat_paths[choose_idx[n]][:-4]
        img_path = img_folder + '/' + seq_name + '/img1/'
        track_name = file_name(n + 1, 4)
        save_path = temp_folder + '/' + track_name
        if not os.path.exists(save_path):
            os.makedirs(save_path)

        for k in range(len(crop_bbox[n])):
            fr_id = crop_bbox[n][k, 0] + 1
            temp_img_path = img_path + file_name(fr_id, 6) + '.jpg'
            img = plt.imread(temp_img_path)
            bbox_img = img[crop_bbox[n][k, 2]:crop_bbox[n][k, 4], crop_bbox[n][k, 1]:crop_bbox[n][k, 3], :]
            # import pdb; pdb.set_trace()
            # bbox_img = scipy.misc.imresize(bbox_img, size=(bbox_size,bbox_size))
            bbox_img = resize(bbox_img, output_shape=(bbox_size, bbox_size))
            # bbox_img = np.array(Image.fromarray(bbox_img).astype(int)).resize((bbox_size,bbox_size))
            bbox_img_path = save_path + '/' + file_name(k, 4) + '.png'
            temp_all_path.append(bbox_img_path)
            plt.imsave(bbox_img_path, bbox_img)
        all_paths.append(temp_all_path)
    # import pdb; pdb.set_trace()

    f_image_size = 160
    distance_metric = 0
    with tf.Graph().as_default():
        # config = tf.ConfigProto()
        # custom_op = config.graph_options.rewrite_options.custom_optimizers.add()
        # custom_op.name = "NpuOptimizer"
        # config.graph_options.rewrite_options.memory_optimization = RewriterConfig.OFF
        with tf.Session() as sess:
            # import pdb; pdb.set_trace()
            image_paths_placeholder = tf.placeholder(tf.string, shape=(None, 1), name='image_paths')
            labels_placeholder = tf.placeholder(tf.int32, shape=(None, 1), name='labels')
            batch_size_placeholder = tf.placeholder(tf.int64, name='batch_size')
            control_placeholder = tf.placeholder(tf.int32, shape=(None, 1), name='control')
            phase_train_placeholder = tf.placeholder(tf.bool, name='phase_train')

            image_size = (f_image_size, f_image_size)
            dataset = tf.data.Dataset.from_tensor_slices(
                (image_paths_placeholder, labels_placeholder, control_placeholder))
            dataset = dataset.map(
                lambda filenames, label, control: facenet.parser_data(filenames, label, control, image_size))
            dataset = dataset.unbatch()
            dataset = dataset.batch(batch_size=batch_size_placeholder, drop_remainder=True)
            iterator = tf.data.make_initializable_iterator(dataset)
            image_batch, label_batch = iterator.get_next()

            # Load the model
            input_map = {'image_batch': image_batch, 'label_batch': label_batch, 'phase_train': phase_train_placeholder}
            facenet.load_model(triplet_model, input_map=input_map)

            # Get output tensor
            embeddings = tf.get_default_graph().get_tensor_by_name("embeddings:0")
            #
            # coord = tf.train.Coordinator()
            # tf.train.start_queue_runners(coord=coord, sess=sess)

            for n in range(len(all_paths)):
                # print(n)
                lfw_batch_size = len(all_paths[n])
                emb_array = evaluate(sess, iterator, image_paths_placeholder, labels_placeholder,
                                     phase_train_placeholder,
                                     batch_size_placeholder, control_placeholder, embeddings, label_batch, all_paths[n],
                                     lfw_batch_size, distance_metric)

                if X[n, 4:, X[n, 0, :, 1] + X[n, 0, :, 2] > 0.5, 0].shape[0] != emb_array.shape[0]:
                    aa = 0
                    import pdb;
                    pdb.set_trace()

                X[n, 4:, X[n, 0, :, 1] + X[n, 0, :, 2] > 0.5, 0] = emb_array

    # import pdb; pdb.set_trace()
    return X, Y


# In[4]:


if __name__ == '__main__':
    main()